{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 2: Inside CrypTensors\n",
    "\n",
    "Note: This tutorial is optional, and can be skipped without any loss of continuity to the following tutorials.\n",
    "\n",
    "\n",
    "In this tutorial, we will take a brief look at the internals of `CrypTensors`.\n",
    "\n",
    "With the `mpc` backend, a `CrypTensor` is an `MPCTensor`: a tensor encrypted using secure multi-party computation (MPC) protocols. To support the mathematical operations required by the `MPCTensor`, CrypTen implements two kinds of secret-sharing protocols: arithmetic secret sharing and binary secret sharing. Arithmetic secret sharing forms the basis for most of the mathematical operations implemented by `MPCTensor`, while binary secret sharing allows for the evaluation of logical expressions.\n",
    "\n",
    "In this tutorial, we'll first introduce the concept of a `CrypTensor` <i>ptype</i> (i.e. <i>private-type</i>), and show how to use it to obtain `MPCTensors` that use arithmetic and binary secret shares. We will also describe how each of these <i>ptypes</i> is used, and how they can be combined to implement desired functionality."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:module 'torchvision.models.mobilenet' has no attribute 'ConvBNReLU'\n"
     ]
    }
   ],
   "source": [
    "#import the libraries\n",
    "import crypten\n",
    "import torch\n",
    "\n",
    "#initialize crypten\n",
    "crypten.init()\n",
    "#Disables OpenMP threads -- needed by @mpc.run_multiprocess which uses fork\n",
    "torch.set_num_threads(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## <i>ptype</i> in CrypTen\n",
    "CrypTen defines the `ptype` (for <i>private-type</i>) attribute of an `MPCTensor` to denote the kind of secret-sharing protocol used in the `CrypTensor`. The `ptype` is, in many ways, analogous to the `dtype` of PyTorch. The `ptype` can take one of two values:\n",
    "\n",
    "- `crypten.mpc.arithmetic` for `ArithmeticSharedTensors`\n",
    "- `crypten.mpc.binary` for `BinarySharedTensors`\n",
    "\n",
    "We can use the `ptype` attribute to create a `CrypTensor` with the appropriate secret-sharing protocol. For example: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_enc internal type: ptype.arithmetic\n",
      "y_enc internal type: ptype.binary\n"
     ]
    }
   ],
   "source": [
    "#Constructing CrypTensors with ptype attribute\n",
    "\n",
    "#arithmetic secret-shared tensors\n",
    "x_enc = crypten.cryptensor([1.0, 2.0, 3.0], ptype=crypten.mpc.arithmetic)\n",
    "print(\"x_enc internal type:\", x_enc.ptype)\n",
    "\n",
    "#binary secret-shared tensors\n",
    "y = torch.tensor([1, 2, 1], dtype=torch.int32)\n",
    "y_enc = crypten.cryptensor(y, ptype=crypten.mpc.binary)\n",
    "print(\"y_enc internal type:\", y_enc.ptype)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Arithmetic secret-sharing\n",
    "Let's look more closely at the `crypten.mpc.arithmetic` <i>ptype</i>. Most of the mathematical operations implemented by `CrypTensors` are implemented using arithmetic secret sharing. As such, `crypten.mpc.arithmetic` is the default <i>ptype</i> for newly generated `CrypTensors`. \n",
    "\n",
    "Let's begin by creating a new `CrypTensor` using `ptype=crypten.mpc.arithmetic` to enforce that the encryption is done via arithmetic secret sharing. We can print values of each share to confirm that values are being encrypted properly. \n",
    "\n",
    "To do so, we need multiple parties, each holding one share. We create them here using the `@mpc.run_multiprocess` function decorator, which we developed to execute CrypTen code from a single script (as we have in a Jupyter notebook). CrypTen follows the standard MPI programming model: it runs a separate process for each party, but each process runs an identical (complete) program. Each process has a `rank` variable to identify itself.\n",
    "\n",
    "Note that the sum of the two `_tensor` attributes below is equal to a scaled representation of the input. (Because MPC requires values to be integers, we scale input floats to a fixed-point encoding before encryption.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Rank 0:\n",
      " MPCTensor(\n",
      "\t_tensor=tensor([ 4977840465844292698,   234311463858409737, -8644375101282040029])\n",
      "\tplain_text=HIDDEN\n",
      "\tptype=ptype.arithmetic\n",
      ")\n",
      "\n",
      "\n",
      "Rank 1:\n",
      " MPCTensor(\n",
      "\t_tensor=tensor([-4977840465844227162,  -234311463858278665,  8644375101282236637])\n",
      "\tplain_text=HIDDEN\n",
      "\tptype=ptype.arithmetic\n",
      ")\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import crypten.mpc as mpc\n",
    "import crypten.communicator as comm \n",
    "\n",
    "@mpc.run_multiprocess(world_size=2)\n",
    "def examine_arithmetic_shares():\n",
    "    x_enc = crypten.cryptensor([1, 2, 3], ptype=crypten.mpc.arithmetic)\n",
    "    \n",
    "    rank = comm.get().get_rank()\n",
    "    crypten.print(f\"\\nRank {rank}:\\n {x_enc}\\n\", in_order=True)\n",
    "        \n",
    "x = examine_arithmetic_shares()"
   ]
  },
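  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the note above, the following sketch adds the two shares shown in the saved output of the previous cell, using plain Python integers rather than any CrypTen call. The shares are random, so a fresh run will print different values; the numbers below are copied from the output stored in this notebook. The scale factor of `2**16` is an assumption inferred from those printed values, not read from CrypTen's configuration, and the variable names are illustrative only.\n",
    "\n",
    "In this particular case the sums do not overflow; in general, share addition would presumably need 64-bit wraparound, which this sketch does not model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shares copied from the saved output of the previous cell (they change on every run)\n",
    "arith_share_rank0 = [4977840465844292698, 234311463858409737, -8644375101282040029]\n",
    "arith_share_rank1 = [-4977840465844227162, -234311463858278665, 8644375101282236637]\n",
    "\n",
    "# Assumption: fixed-point scale of 2**16, inferred from the printed shares\n",
    "ASSUMED_SCALE = 2 ** 16\n",
    "\n",
    "# Adding the shares element-wise recovers the fixed-point encoding of the input\n",
    "encoded = [a + b for a, b in zip(arith_share_rank0, arith_share_rank1)]\n",
    "\n",
    "# Dividing by the scale recovers the original values\n",
    "decoded = [value / ASSUMED_SCALE for value in encoded]\n",
    "\n",
    "print(\"encoded:\", encoded)  # [65536, 131072, 196608]\n",
    "print(\"decoded:\", decoded)  # [1.0, 2.0, 3.0]"
   ]
  },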
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Binary secret-sharing\n",
    "The second type of secret-sharing implemented in CrypTen is binary or XOR secret-sharing. This type of secret-sharing allows greater efficiency in evaluating logical expressions. \n",
    "\n",
    "Let's look more closely at the `crypten.mpc.binary` <i>ptype</i>. Most of the logical operations implemented by `CrypTensors` are implemented using binary secret sharing. We typically use this type of secret sharing when we want to evaluate binary operators (e.g. `^ & | >> <<`) or logical operations (like comparators).\n",
    "\n",
    "Let's begin by creating a new `CrypTensor` using `ptype=crypten.mpc.binary` to enforce that the encryption is done via binary secret sharing. We can print values of each share to confirm that values are being encrypted properly, as we did for arithmetic secret-shares.\n",
    "\n",
    "(Note that the XOR of the two `_tensor` attributes below is equal to the unscaled input.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Rank 0:\n",
      " MPCTensor(\n",
      "\t_tensor=tensor([-3617348383499570248,  -581960226550774565])\n",
      "\tplain_text=HIDDEN\n",
      "\tptype=ptype.binary\n",
      ")\n",
      "\n",
      "\n",
      "Rank 1:\n",
      " MPCTensor(\n",
      "\t_tensor=tensor([-3617348383499570246,  -581960226550774568])\n",
      "\tplain_text=HIDDEN\n",
      "\tptype=ptype.binary\n",
      ")\n",
      "\n"
     ]
    }
   ],
   "source": [
    "@mpc.run_multiprocess(world_size=2)\n",
    "def examine_binary_shares():\n",
    "    x_enc = crypten.cryptensor([2, 3], ptype=crypten.mpc.binary)\n",
    "    \n",
    "    rank = comm.get().get_rank()\n",
    "    crypten.print(f\"\\nRank {rank}:\\n {x_enc}\\n\", in_order=True)\n",
    "        \n",
    "x = examine_binary_shares()"
   ]
  },
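  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The XOR relationship noted above can be checked the same way. This sketch XORs the two shares shown in the saved output of the previous cell using plain Python integers (no CrypTen calls). The values are copied from the stored output and will differ on a fresh run; the variable names are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shares copied from the saved output of the previous cell (they change on every run)\n",
    "binary_share_rank0 = [-3617348383499570248, -581960226550774565]\n",
    "binary_share_rank1 = [-3617348383499570246, -581960226550774568]\n",
    "\n",
    "# XOR-ing the shares element-wise recovers the input directly (no scaling involved)\n",
    "recovered = [a ^ b for a, b in zip(binary_share_rank0, binary_share_rank1)]\n",
    "\n",
    "print(\"recovered:\", recovered)  # [2, 3]"
   ]
  },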
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using Both Secret-sharing Protocols\n",
    "Quite often, a mathematical function needs both arithmetic (additive) and binary (XOR) secret sharing for efficient evaluation. Functions that require conversions between sharing types include comparators (`>, >=, <, <=, ==, !=`) as well as functions derived from them (`abs, sign, relu`, etc.). For a full list of supported functions, please see the CrypTen documentation.\n",
    "\n",
    "CrypTen provides functionality for converting between <i>ptypes</i>. Conversion can be done using the `.to()` function with a `crypten.ptype` input, or by calling the `.arithmetic()` and `.binary()` conversion functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "to(crypten.binary):\n",
      "  ptype: ptype.binary\n",
      "  plaintext: tensor([1., 2., 3.])\n",
      "\n",
      "to(crypten.arithmetic):\n",
      "  ptype: ptype.arithmetic\n",
      "  plaintext: tensor([1., 2., 3.])\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from crypten.mpc import MPCTensor\n",
    "\n",
    "@mpc.run_multiprocess(world_size=2)\n",
    "def examine_conversion():\n",
    "    x = torch.tensor([1, 2, 3])\n",
    "    rank = comm.get().get_rank()\n",
    "\n",
    "    # create an MPCTensor with arithmetic secret sharing\n",
    "    x_enc_arithmetic = MPCTensor(x, ptype=crypten.mpc.arithmetic)\n",
    "    \n",
    "    # To binary\n",
    "    x_enc_binary = x_enc_arithmetic.to(crypten.mpc.binary)\n",
    "    x_from_binary = x_enc_binary.get_plain_text()\n",
    "    \n",
    "    # print only once\n",
    "    crypten.print(\"to(crypten.binary):\")\n",
    "    crypten.print(f\"  ptype: {x_enc_binary.ptype}\\n  plaintext: {x_from_binary}\\n\")\n",
    "\n",
    "        \n",
    "    # Back to arithmetic, starting from the binary tensor\n",
    "    x_enc_arithmetic = x_enc_binary.to(crypten.mpc.arithmetic)\n",
    "    x_from_arithmetic = x_enc_arithmetic.get_plain_text()\n",
    "    \n",
    "    # print only once\n",
    "    crypten.print(\"to(crypten.arithmetic):\")\n",
    "    crypten.print(f\"  ptype: {x_enc_arithmetic.ptype}\\n  plaintext: {x_from_arithmetic}\\n\")\n",
    "\n",
    "        \n",
    "z = examine_conversion()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Sources\n",
    "CrypTen follows the standard MPI programming model: it runs a separate process for each party, but each process runs an identical (complete) program. Each process has a `rank` variable to identify itself.\n",
    "\n",
    "If the process with rank `i` is the source of data `x`, then `x` gets encrypted with `i` as its source value (denoted as `src`). However, MPI protocols require that all processes provide a tensor with the same size as the input. CrypTen ignores all data provided by non-source processes when encrypting.\n",
    "\n",
    "In the next example, we'll show how to use the `rank` and `src` values to encrypt tensors. Here, we will have each of 3 parties generate a value `x` which is equal to its own `rank` value. Within the loop, 3 encrypted tensors are created, each with a different source. When these tensors are decrypted, we can verify that the tensors are generated using the tensor provided by the source process.\n",
    "\n",
    "(Note that `crypten.cryptensor` uses rank 0 as the default source if none is provided.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Rank 0: 0\n",
      "Rank 1: 1\n",
      "Rank 2: 2\n",
      "Source 0: 0.0\n",
      "Source 1: 1.0\n",
      "Source 2: 2.0\n"
     ]
    }
   ],
   "source": [
    "@mpc.run_multiprocess(world_size=3)\n",
    "def examine_sources():\n",
    "    # Create a different tensor on each rank\n",
    "    rank = comm.get().get_rank()\n",
    "    x = torch.tensor(rank)\n",
    "    crypten.print(f\"Rank {rank}: {x}\", in_order=True)\n",
    "    \n",
    "    # Encrypt x once per source rank, then decrypt to see whose value was used\n",
    "    world_size = comm.get().get_world_size()\n",
    "    for i in range(world_size):\n",
    "        x_enc = crypten.cryptensor(x, src=i)\n",
    "        z = x_enc.get_plain_text()\n",
    "        \n",
    "        # Only print from one process to avoid duplicates\n",
    "        crypten.print(f\"Source {i}: {z}\")\n",
    "        \n",
    "x = examine_sources()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
